{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import collections\n",
    "import logging\n",
    "import json\n",
    "import re\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "\n",
    "from pytorch_pretrained_bert.tokenization import BertTokenizer\n",
    "from pytorch_pretrained_bert.modeling import BertModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "n_gpu = torch.cuda.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer_indexes = [-1,-2]\n",
    "pretrain_params = \"bert-base-uncased\"\n",
    "tokenizer = BertTokenizer.from_pretrained(pretrain_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BertModel.from_pretrained(pretrain_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.to(device)\n",
    "# model.eval()\n",
    "# if n_gpu > 1:\n",
    "#     model = torch.nn.DataParallel(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sentence_root = '/p300/flickr30k_images/flickr30k_anno/sent_anno.json'\n",
    "# sent_anno = json.load(open(sentence_root, 'r'))\n",
    "# MAX_LEN = 88 # 86+2\n",
    "# max_len = 0\n",
    "# ret = {}\n",
    "# for im_id, v in sent_anno.items():\n",
    "#     tokens = []\n",
    "#     for sent in v:\n",
    "#         seq = sent['sentence']\n",
    "#         seq = \"[CLS] \" + seq + \" [SEP]\"\n",
    "#         tokenized_text = tokenizer.tokenize(seq)\n",
    "#         if len(tokenized_text) <= MAX_LEN:\n",
    "#             tokenized_text += ['[PAD]' for _ in range(MAX_LEN - len(tokenized_text))]\n",
    "#         indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)\n",
    "#         tokens_tensor = torch.tensor(indexed_tokens).cuda()\n",
    "#         tokens.append(tokens_tensor)\n",
    "#     tokens_tensor = torch.stack(tokens)\n",
    "    \n",
    "#     with torch.no_grad():\n",
    "#         encoded_layers, _ = model(tokens_tensor, None)\n",
    "#     ret[im_id] = encoded_layers[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dist(vec1, vec2):\n",
    "    return (vec1*vec2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "seqs = ['man', 'woman', 'kid', 'son', 'mother', 'cat', 'dog', 'duck', 'tree', 'flower', 'building']\n",
    "embeds = []\n",
    "for seq in seqs:\n",
    "    seq = \"[CLS] \" + seq + \" [SEP]\"\n",
    "    tokenized_text = tokenizer.tokenize(seq)\n",
    "    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)\n",
    "    tokens_tensor = torch.tensor(indexed_tokens)\n",
    "    tokens_tensor = torch.stack([tokens_tensor])\n",
    "    encoded_layers, _ = model(tokens_tensor, None)\n",
    "    embeds.append(encoded_layers[-1][0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.6691, grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(dist(embeds[0], embeds[-3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.6236, grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(dist(embeds[0], embeds[3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skipthoughts import UniSkip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/venv/pose/lib/python3.6/site-packages/torch/nn/modules/rnn.py:54: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.25 and num_layers=1\n",
      "  \"num_layers={}\".format(dropout, num_layers))\n"
     ]
    }
   ],
   "source": [
    "dir_st = 'data/skip-thoughts'\n",
    "vocab = ['A', 'lot', 'of', 'beautiful', 'man', 'woman', 'kid', 'son', 'mother', 'cat', \n",
    "         'dog', 'duck', 'tree', 'flower', 'building']\n",
    "uniskip = UniSkip(dir_st, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "input = torch.LongTensor([\n",
    "    [1,5,0,0],\n",
    "    [1,2,3,6],\n",
    "    [1,10,0,0],\n",
    "    [1,11,0,0],\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeds = uniskip(input, lengths=[2,4,2,2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(38.1826, grad_fn=<SumBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(dist(embeds[0], embeds[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(36.2752, grad_fn=<SumBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(dist(embeds[0], embeds[2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(39.9834, grad_fn=<SumBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(dist(embeds[2], embeds[3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] A man walked across a street holding a basket on top of his head filled with fruit and vegetables with one hand , while holding bags of grapes in his right hand . [SEP]\n"
     ]
    }
   ],
   "source": [
    "seq = 'A man walked across a street holding a basket on top of his head filled with fruit and vegetables with one hand , while holding bags of grapes in his right hand .'\n",
    "seq = \"[CLS] \" + seq + \" [SEP]\"\n",
    "print(seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[CLS]', 'a', 'man', 'walked', 'across', 'a', 'street', 'holding', 'a', 'basket', 'on', 'top', 'of', 'his', 'head', 'filled', 'with', 'fruit', 'and', 'vegetables', 'with', 'one', 'hand', ',', 'while', 'holding', 'bags', 'of', 'grapes', 'in', 'his', 'right', 'hand', '.', '[SEP]']\n"
     ]
    }
   ],
   "source": [
    "tokenized_text = tokenizer.tokenize(seq)\n",
    "print(tokenized_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[101, 1037, 2158, 2939, 2408, 1037, 2395, 3173, 1037, 10810, 2006, 2327, 1997, 2010, 2132, 3561, 2007, 5909, 1998, 11546, 2007, 2028, 2192, 1010, 2096, 3173, 8641, 1997, 16575, 1999, 2010, 2157, 2192, 1012, 102]\n",
      "35\n"
     ]
    }
   ],
   "source": [
    "indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)\n",
    "print(indexed_tokens)\n",
    "print(len(indexed_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens_tensor = torch.tensor(indexed_tokens)\n",
    "tokens_tensor = torch.stack([tokens_tensor])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_layers, _ = model(tokens_tensor, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 35, 768])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded_layers[-1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pose",
   "language": "python",
   "name": "pose"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
